import datetime as dt
import re
import traceback
from functools import wraps
from typing import Any, Optional

import connectorx as cx
import pandas as pd
from croniter import croniter
from pandas.api.types import infer_dtype
from pangres import upsert
from sqlalchemy import BIGINT, Float, Integer, TIMESTAMP, VARCHAR, create_engine, text

from shares import utils
from shares.models import Query
from shares.utils import load_config, store_config

def table_exists(table_name: str, con) -> bool:
    """
    Check whether a table already exists in the ``public`` schema.

    Uses PostgreSQL's ``to_regclass``, which returns NULL for unknown
    relations. The table name is passed as a bound parameter instead of
    being interpolated into the SQL string, so an odd or hostile name
    cannot inject SQL.

    :param table_name: table name (without schema prefix)
    :param con: SQLAlchemy engine/connection accepted by ``pd.read_sql``
    :return: True if the table exists
    """
    df = pd.read_sql(
        text("SELECT to_regclass(:qualified_name);"),
        con,
        params={"qualified_name": "public.{table_name}".format(table_name=table_name)},
    )
    # Single row, single column: NULL (-> None) means the table is missing.
    # iloc[0, 0] avoids depending on the result column's name.
    return df.iloc[0, 0] is not None





def change_key(table_name, engine, unique_cols) -> None:
    """
    Add a unique constraint and an auto-increment ``id`` to an existing table.

    The serial ``id`` becomes the primary key — connector-x uses it to
    partition queries for faster parallel loading — and the unique
    constraint on ``unique_cols`` deduplicates repeated inserts.

    NOTE(review): ``DROP CONSTRAINT "{table}_pkey"`` assumes the table
    already has a primary key named ``<table>_pkey`` — confirm this holds
    for tables created by the pangres upsert path.

    :param table_name: table name
    :param unique_cols: column names forming the unique constraint
    :param engine: SQLAlchemy engine
    :return: the DDL execution result (despite the ``None`` annotation)
    """
    with engine.connect() as connection:
        result = connection.execute(
            text("""
            ALTER TABLE public.{table}
              ADD COLUMN  IF NOT EXISTS  id serial,
              DROP CONSTRAINT "{table}_pkey",
              ADD CONSTRAINT "{table}_pkey" PRIMARY KEY ("id"),
              ADD CONSTRAINT "{table}_unique_key" UNIQUE ({unique_cols});
            """.format(table=table_name, unique_cols='"' + '","'.join(unique_cols) + '"')))
        return result




def normalize_code(codes) -> list:
    """Normalize stock codes by appending an exchange suffix.

    The first 6-digit run in each input is extracted and suffixed by its
    leading digit: 0/1/3 -> ``.SZ``, 8/4 -> ``.BJ``, anything else -> ``.SH``.
    """
    six_digits = re.compile(r"\d{6}")
    normalized = []
    for raw in codes:
        num = six_digits.findall(raw)[0]
        leading = num[0]
        if leading in ('0', '1', '3'):
            suffix = ".SZ"
        elif leading in ('8', '4'):
            suffix = ".BJ"
        else:
            suffix = ".SH"
        normalized.append(num + suffix)
    return normalized


def normal_cols(df: pd.DataFrame) -> pd.DataFrame:
    """
    Auto-convert date-like columns to datetime and stock-code columns to
    exchange-suffixed form (e.g. ``000001.SZ``).

    Date-like columns (matched by name) are parsed with
    :func:`pandas.to_datetime`; a 5-character string value such as
    ``"01-23"`` is assumed to be month-day and gets the current year
    prepended first. Code-like columns are zero-padded to 6 digits when
    numeric and then normalized via :func:`normalize_code`.

    :param df: input frame; not modified (a copy is returned)
    :return: converted copy of ``df``
    """
    df_ret = df.copy()
    date_names = ['公告日期', '实际公告日期', '报告期', 'date', 'updated', 'time', '更新时间']
    for col, col_dtype in zip(df.columns, df.dtypes):
        name = str(col)
        if name in date_names or '日期' in name:
            # Guard against empty frames and non-string first values
            # (e.g. already-parsed datetimes): the original code called
            # len(iloc[0]) unconditionally and crashed on both.
            if len(df_ret):
                first = df_ret[col].iloc[0]
                if isinstance(first, str) and len(first) == 5:
                    # "MM-DD" style value: prepend the current year.
                    df_ret[col] = dt.datetime.now().strftime("%Y") + "-" + df_ret[col]
            df_ret[col] = pd.to_datetime(df_ret[col])
        elif name in ['code', '证券代码'] or '代码' in name:
            if str(col_dtype) != 'object':
                # Numeric codes lose their leading zeros; restore them.
                df_ret[col] = df_ret[col].astype(str).str.zfill(6)
            df_ret[col] = normalize_code(df_ret[col].tolist())
    return df_ret


def rename_col_names(cols) -> list:
    """
    Sanitize column names: spaces become underscores, and both round
    brackets become ``$``.

    :param cols: iterable of column-name strings
    :return: list of sanitized names
    """
    sanitized = []
    for name in cols:
        name = name.replace(" ", "_")
        name = name.replace("(", "$")
        name = name.replace(")", "$")
        sanitized.append(name)
    return sanitized


def dtypes_normal(df: pd.DataFrame) -> dict:
    """
    Build a SQLAlchemy column-type mapping from a DataFrame's dtypes.

    Both regular columns and index levels are mapped, so the result can be
    passed as the ``dtype`` argument of ``pangres.upsert``.

    :param df: frame whose columns and index levels are to be mapped
    :return: dict of column/level name -> SQLAlchemy type instance
    """
    type_dict = {}
    # A MultiIndex reports inferred_type == 'mixed' and exposes per-level
    # dtypes; a single index only has one inferred type.
    index_types = [d for d in df.index.dtypes] if df.index.inferred_type == 'mixed' else [df.index.inferred_type] # type: ignore
    for i, j in zip(list(df.columns) + list(df.index.names), list(df.dtypes) + index_types):
        # infer_dtype yields a semantic kind ('string', 'floating', ...),
        # which is more reliable than the raw numpy dtype for object columns.
        if i in df.columns:
            dtype = infer_dtype(df[i].values, skipna=True)
        else:
            dtype = infer_dtype(df.index.get_level_values(i), skipna=True)
        if 'date' in dtype or 'time' in dtype:
            type_dict.update({i: TIMESTAMP()})
        elif i == 'ts_code':
            # Stock codes like '000001.SZ' fit comfortably in 20 chars.
            type_dict.update({i: VARCHAR(20)})
        elif "string" == dtype:
            if i in df.columns:
                len_series = df[i].str.len()
            else:
                len_series = df.index.get_level_values(i).str.len()
            # At least 255, and double the current max to leave headroom
            # for longer values inserted later.
            type_dict.update({i: VARCHAR(length=max(255, len_series.max(skipna=True) * 2))})
        elif "float" in dtype or 'decimal' in dtype:
            type_dict.update({i: Float(precision=4, asdecimal=True)})
        elif "int32" in str(j):
            type_dict.update({i: Integer()})
        elif "int64" in str(j):
            type_dict.update({i: BIGINT()})
    return type_dict


class Store:
    """Persist DataFrames into PostgreSQL and query them back.

    Writes go through ``pangres.upsert`` (insert-or-update keyed on a
    unique index); reads can optionally use connectorx for faster loading.
    """

    def __init__(self, con_url: str = "", debug: bool = True, print_log: bool = True):
        """
        Data persistence object.

        :param con_url: database connection URL; when empty/None, the URL
            stored in the local config file is used instead
        :param debug: when True, crontab expressions in :meth:`save` are ignored
        :param print_log: when True, progress is printed during execution

        Example::

            store = Store(con_url="postgresql+psycopg2://username:password@host:port/db", debug=False)
        """
        self.config = load_config()
        if con_url is None or len(con_url) < 1:
            # No URL supplied: fall back to the one persisted in the config.
            self.con_url = self.config.get("con_url")
        else:
            self.con_url = con_url
            # Updates the in-memory config only; call store_config() to persist.
            self.config['con_url'] = con_url
        self.debug = debug
        self.print_log = print_log
        self.engine = create_engine(self.con_url)  # type: ignore

    def store_config(self) -> None:
        """Persist the in-memory configuration to the config file."""
        store_config(self.config)

    def set_con_url(self, con_url: str) -> None:
        """
        Set the database connection URL (persisted when non-empty) and
        rebuild the engine.

        :param con_url: database connection URL; empty/None reloads the
            URL from the config file
        """
        config = load_config()
        if con_url is None or len(con_url) < 1:
            self.con_url = config.get("con_url")
        else:
            self.con_url = con_url
            config['con_url'] = con_url
            store_config(config)
        # Fix: the engine previously kept pointing at the old URL; recreate
        # it so subsequent operations actually use the new connection.
        self.engine = create_engine(self.con_url)  # type: ignore

    def save_df(self, df: pd.DataFrame, table_name: str, unique_cols: Optional[list] = None) -> None:
        """
        Upsert a DataFrame into ``table_name``.

        On the first write (the table did not exist yet) a serial ``id``
        primary key and a unique index over ``unique_cols`` are added via
        :func:`change_key`.

        :param df: DataFrame to save; ``None`` or empty frames are ignored
        :param table_name: target table name
        :param unique_cols: columns forming the unique index (they become
            the DataFrame index before upserting)
        """
        cols = unique_cols if unique_cols is not None else []
        exists = table_exists(table_name, self.engine)
        if df is not None and df.shape[0] > 0:
            # Keep each statement under PostgreSQL's bind-parameter limit;
            # max(1, ...) guards against chunksize 0 for very wide frames.
            chunksize = max(1, int(32766 / df.shape[1] / 16))
            df = df.set_index(cols, drop=True)
            # Strip characters that are awkward in SQL identifiers.
            df.rename(
                columns={col: col.replace("%", "百分比").replace("(", "").replace(")", "")
                         .replace("（", '').replace("）", "")
                         for col in df.columns},
                inplace=True)
            df['updated_at'] = dt.datetime.now()
            upsert(con=self.engine,
                   df=df,
                   table_name=table_name,
                   if_row_exists='update',
                   dtype=dtypes_normal(df), chunksize=chunksize, add_new_columns=True)
            if not exists:
                # First write: add surrogate id PK and the unique index.
                change_key(table_name, self.engine, cols)

    def save(self, table_name: str, unique_cols: Optional[list] = None, crontab: str = "* * * * *") -> Any:
        """
        Decorator that persists the DataFrame returned by the wrapped function.

        :param table_name: target table name; falls back to the wrapped
            function's name when None
        :param unique_cols: columns forming the unique index
        :param crontab: crontab expression; when ``debug=False`` the wrapped
            function only runs if the current time matches the expression,
            otherwise the call is skipped
        :return: the decorator
        """
        debug = self.debug
        print_log = self.print_log
        cols = unique_cols if unique_cols is not None else []

        def _memoize(fn):
            @wraps(fn)  # preserve the wrapped function's metadata
            def __memoize(*args, **kwargs):
                try:
                    if croniter.match(crontab, dt.datetime.now()) or debug:
                        table = table_name if table_name is not None else fn.__name__
                        exists = table_exists(table, self.engine)
                        if print_log:
                            print("------------------------------------------------------------------------------")
                            print(table + " start: " + dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
                            print("位置参数:", args, "命名参数:", kwargs)
                        df = fn(*args, **kwargs)
                        if df is not None and df.shape[0] > 0:
                            # Keep each statement under PostgreSQL's
                            # bind-parameter limit; guard against 0.
                            chunksize = max(1, int(32766 / df.shape[1] / 2))
                            df = df.set_index(cols, drop=True)
                            # Strip characters awkward in SQL identifiers.
                            df.rename(
                                columns={col: col.replace("%", "百分比").replace("(", "").replace(")", "")
                                         .replace("（", '').replace("）", "")
                                         for col in df.columns},
                                inplace=True)
                            df['updated_at'] = dt.datetime.now()
                            upsert(con=self.engine,
                                   df=df,
                                   table_name=table,
                                   schema='public',
                                   if_row_exists='update',
                                   dtype=dtypes_normal(df), chunksize=chunksize, add_new_columns=True)
                            if not exists:
                                # First write: add surrogate id PK and unique index.
                                change_key(table, self.engine, cols)
                        if print_log:
                            print(table + "   end: " + dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
                        return df
                except Exception as e:
                    # Best effort by design: log and swallow so one failing
                    # job does not kill a scheduler loop running many tasks.
                    print(e)
                    traceback.print_exc()

            return __memoize

        return _memoize

    def sync_query(self, query: Query, cn_column_names=False, connectorx=True) -> pd.DataFrame:
        """
        Load the data behind an API query object from the database.

        Some APIs' query values differ from their return values, so this
        method only suits part of the APIs.

        :param query: e.g. ``api.stock_list()``
        :param cn_column_names: return Chinese column names when True
        :param connectorx: load via connectorx (faster) instead of pandas
        :return: pd.DataFrame
        """
        table = query.method
        sql = "select * from {table}".format(table=table)
        if len(query.params) > 0:
            # NOTE(review): values are interpolated directly into the SQL;
            # assumes query.params comes from trusted API code, not end users.
            sql = sql + " where " + (" and ".join([k + "=" + "'" + v + "'" for k, v in query.params.items()]))
        if connectorx:
            # connectorx expects a plain 'postgresql://' URL without the driver suffix.
            df = cx.read_sql(self.con_url.replace("+psycopg2:", ":"), sql)  # type: ignore
        else:
            df = pd.read_sql(sql, self.engine)
        if df.shape[0] > 0:  # type: ignore
            return utils.change_type_and_column_names(df, query, cn_column_names, change_types=False)  # type: ignore
        return df  # type: ignore

    def sync_sql(self, sql, cn_column_names=False, query=None, connectorx=True) -> pd.DataFrame:
        """
        Run a raw SQL query against the database.

        :param sql: raw SQL string
        :param cn_column_names: return Chinese column names when True;
            requires ``query`` to supply the column-name mapping
        :param query: API query object, e.g. ``api.stock_list()``
        :param connectorx: load via connectorx (faster) instead of pandas
        :return: pd.DataFrame
        """
        if connectorx:
            # connectorx expects a plain 'postgresql://' URL without the driver suffix.
            df = cx.read_sql(self.con_url.replace("+psycopg2:", ":"), sql)  # type: ignore
        else:
            df = pd.read_sql(sql, self.engine)
        if df.shape[0] > 0 and cn_column_names and query is not None:  # type: ignore
            return utils.change_type_and_column_names(df, query, cn_column_names, change_types=False)  # type: ignore
        return df  # type: ignore

    def sync_execute(self, sql) -> Any:
        """
        Execute an arbitrary SQL statement.

        :param sql: SQL string to execute
        :return: the SQLAlchemy result object
            (NOTE(review): returned after the connection context closes —
            streaming access may be unavailable; confirm callers only need
            buffered results)
        """
        with self.engine.connect() as connection:
            result = connection.execute(text(sql))
            return result


# Explicit public API of this module; Store is the main entry point.
__all__ = ["Store", 'normalize_code', 'normal_cols', 'dtypes_normal']
